反演+矩阵树

首先题目要求的是 Ti=1n1wei×gcd(we1,...,wen1)\sum\limits_{T}\sum\limits_{i=1}^{n-1}w_{e_{i}}\times gcd(w_{e_1},...,w_{e_{n-1}})

很明显的可以用反演,也可以直接套 ϕ1=id\phi*1=id

那么就可以得出 d=1max(w)ϕ(d)Td(gcd(weT))i=1n1wei\sum\limits_{d=1}^{max(w)}\phi(d)\sum\limits_{T}^{d|(gcd(w_{e\in T}))}\sum\limits_{i=1}^{n-1}w_{e_i}

后面部分需要用到矩阵树定理,但矩阵树求得是 Ti=1n1wei\sum\limits_T\prod\limits_{i=1}^{n-1}w_{e_i},和上面的不一样,所以需要稍作转换

考虑将边权用一个一次函数 wx+1wx+1 表示,在模 x2x^2 下作乘法,那么你会发现一次项系数就是边权和

所以不妨定义一个多项式四则运算,加减直接对应加减就行了

乘法 (ax+b)(cx+d)=(ad+bc)x+bd(ax+b)(cx+d)=(ad+bc)x+bd

除法,考虑 (cx+d)(cx+d) 的逆元,即我们需要求 (Ax+B)(cx+d)1(mod x2)(Ax+B)(cx+d)\equiv 1(mod~x^2)

Bd=1B=1dBd=1\Rightarrow B=\frac{1}{d}

Adx+Bcx=0A=cd2Adx+Bcx=0\Rightarrow A=-\frac{c}{d^2}

那么就可以得到 ax+bcx+d=(ax+b)(cd2x+1d)=adbcd2x+bd\frac{ax+b}{cx+d}=(ax+b)(-\frac{c}{d^2}x+\frac{1}{d})=\frac{ad-bc}{d^2}x+\frac{b}{d}

直接高斯消元即可,时间复杂度 O(n3max(w))O(n^3max(w)),有点勉强,这里有一个优化,即加边大于等于 n1n-1 才进行矩阵树,这样就是 O(144n4)O(144n^4)

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#define MOD (998244353)
using namespace std;
typedef long long ll;

ll qmi(ll a, ll b);
int n, m;
struct num
{
ll x, y;
num operator + (const num a) const
{return {(x + a.x) % MOD, (y + a.y) % MOD};}
num operator - (const num a) const
{return {(x - a.x + MOD) % MOD, (y - a.y + MOD) % MOD};}
num operator * (const num a) const
{return {(x * a.y + y * a.x) % MOD, y * a.y % MOD};}
num operator / (const num a) const
{
ll inv = qmi(a.y, MOD - 2);
return {((x * a.y - y * a.x) % MOD + MOD) * inv % MOD * inv % MOD, y * inv % MOD};
}
} a[35][35];
struct edge
{
int u, v, w;
} e[1005];

ll qmi(ll a, ll b)
{
ll res = 1;
while (b)
{
if (b & 1)
res = res * a % MOD;
a = a * a % MOD;
b >>= 1;
}
return res;
}

int gauss()
{
num res = num({0, 1});
int w = 1;
for (int i = 1; i < n; i++)
{
for (int j = i + 1; j < n; j++)
{
if (a[j][i].y)
{
swap(a[j], a[i]), w = -w;
break;
}
}
num inv = num({0, 1}) / a[i][i];
for (int j = i + 1; j < n; j++)
{
num d = a[j][i] * inv;
for (int k = i; k < n; k++)
a[j][k] = a[j][k] - a[i][k] * d;
}
}
for (int i = 1; i < n; i++)
res = res * a[i][i];
return w > 0 ? res.x : (num({0, 0}) - res).x;
}

int phi(int x)
{
int res = x;
for (int i = 2; i <= x; i++)
{
if (x % i == 0)
res -= res / i;
while (x % i == 0)
x /= i;
}
if (x > 1)
res -= res / x;
return res;
}

int main()
{
int mx = 0, ans = 0;
scanf("%d%d", &n, &m);
for (int i = 1; i <= m; i++)
{
scanf("%d%d%d", &e[i].u, &e[i].v, &e[i].w);
mx = max(mx, e[i].w);
}
for (int i = 1; i <= mx; i++)
{
memset(a, 0, sizeof(a));
int tot = 0;
for (int j = 1; j <= m; j++)
{
if (e[j].w % i)
continue;
tot++;
int u = e[j].u, v = e[j].v, w = e[j].w;
num P = num({w, 1});
a[u][u] = a[u][u] + P, a[v][v] = a[v][v] + P;
a[u][v] = a[u][v] - P, a[v][u] = a[v][u] - P;
}
if (tot < n - 1)
continue;
int t = gauss();
ans = (ans + (ll)phi(i)* t % MOD) % MOD;
}
printf("%d", ans);
return 0;
}